import os
import megatron
import functools
import torch


def record_function_decorator(func):
    @functools.wraps(func)
    def new_func(*args, **kwargs):
        with torch.profiler.record_function(func.__name__):
            return func(*args, **kwargs)

    return new_func


original_forward_step = megatron.core.pipeline_parallel.schedules.forward_step
original_backward_step = megatron.core.pipeline_parallel.schedules.backward_step


@record_function_decorator
def forward_step(
    forward_step_func,
    data_iterator,
    model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    encoder_decoder_xattn=False,
):
    return original_forward_step(
        forward_step_func,
        data_iterator,
        model,
        num_microbatches,
        input_tensor,
        forward_data_store,
        config,
        collect_non_loss_data,
        checkpoint_activations_microbatch,
        is_first_microbatch,
        current_microbatch,
        encoder_decoder_xattn,
    )


@record_function_decorator
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
    return original_backward_step(
        input_tensor, output_tensor, output_tensor_grad, model_type, config
    )


enable_profiler = int(os.getenv("ENABLE_PROFILER", 0))
if enable_profiler:
    megatron.core.pipeline_parallel.schedules.forward_step = forward_step
    megatron.core.pipeline_parallel.schedules.backward_step = backward_step
